{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the iris dataset bundled with scikit-learn (sklearn.datasets)\n",
    "iris = datasets.load_iris()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature matrix and class labels, bound in a single statement\n",
    "X, y = iris.data, iris.target"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([112,  49,   0,  35,  61, 117,  10,  26,  74,   2, 121,  81,  70,\n",
       "        58,  34,  28,  88,  87,  40,  62,  75,  38,  21,  20,  57,  63,\n",
       "       134,   8, 128,  64,  37, 111, 116,  25,  17,  51,  65, 133,  19,\n",
       "        89,  79, 104,  77,  15, 106,  56, 122,  27,  48,  54,  60,  52,\n",
       "       126,  98,  90,  78, 136, 119,  30, 132,   9,  68,  14, 105, 148,\n",
       "        99,  22,  84, 147, 124,  80,  29, 113, 118, 144, 141, 140,  33,\n",
       "         6, 109, 149, 142,  12,  59,  32, 120,  93,  50, 138,  41,  43,\n",
       "        94,  47,  91,  39,  95,  83, 145,  71,  13, 125, 123,  36, 100,\n",
       "        23, 130,  46, 101, 110,  86,  72, 139,  85, 137,  18,  53,  45,\n",
       "        82,  67,  73, 115,   3,   5,  24, 102,  42,  31, 107,  16, 146,\n",
       "       103, 143,   7,  97,  69, 127,  92, 135,  11, 131, 108,   4,  44,\n",
       "       129, 114,   1,  76,  96,  55,  66])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 对索引进行随机排列\n",
    "shuffle_indexes = np.random.permutation(len(X))\n",
    "shuffle_indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hold out 20% of the samples for testing; int() truncates to a whole sample count\n",
    "test_ratio = 0.2\n",
    "test_size = int(test_ratio * len(X))\n",
    "test_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first test_size shuffled indices form the test set; the remainder is the training set\n",
    "test_indexes, train_indexes = shuffle_indexes[:test_size], shuffle_indexes[test_size:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index features and labels with the same index arrays so rows stay aligned\n",
    "X_train, y_train = X[train_indexes], y[train_indexes]\n",
    "X_test, y_test = X[test_indexes], y[test_indexes]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 从py文件导入函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from c2_model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(120, 4) (30, 4)\n",
      "(120,) (30,)\n"
     ]
    }
   ],
   "source": [
    "# The project-local train_test_split is unpacked as (X_train, y_train, X_test, y_test),\n",
    "# which is NOT scikit-learn's (X_train, X_test, y_train, y_test) order.\n",
    "X_train, y_train, X_test, y_test = train_test_split(X, y)\n",
    "\n",
    "# Fail fast if features and labels are misaligned (e.g. the wrong train_test_split\n",
    "# is bound in the kernel) or if samples were lost or duplicated by the split.\n",
    "assert len(X_train) == len(y_train) and len(X_test) == len(y_test)\n",
    "assert len(X_train) + len(X_test) == len(X)\n",
    "\n",
    "print(X_train.shape, X_test.shape)\n",
    "print(y_train.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from c1_KNNClassifier import KNNClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the project-local KNN classifier with k=3\n",
    "# (k is presumably the number of neighbours consulted; see c1_KNNClassifier to confirm)\n",
    "my_knn_cls = KNNClassifier(k=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNN(k = 3)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the classifier on the training split; the cell displays the repr of whatever fit() returns\n",
    "my_knn_cls.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 0, 1, 0, 2, 1, 1, 2, 0, 0, 2, 2, 2, 2, 2, 0, 0, 0, 1, 2, 1, 1,\n",
       "       0, 0, 0, 2, 0, 2, 1, 1])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict a label for every held-out test sample and display the predictions\n",
    "y_predict = my_knn_cls.predict(X_test)\n",
    "y_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9666666666666667"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy = fraction of test samples whose predicted label equals the true label.\n",
    "# np.mean reduces the boolean mask natively instead of iterating it with the builtin sum.\n",
    "accuracy = np.mean(y_predict == y_test)\n",
    "accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# sklearn中的train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(120, 4) (30, 4)\n",
      "(120,) (30,)\n"
     ]
    }
   ],
   "source": [
    "X_train, X_test, y_train, y_test =train_test_split(X, y , test_size=0.2, random_state=666)\n",
    "print(X_train.shape, X_test.shape)\n",
    "print(y_train.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
